{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
    "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
    "\n",
    "    !pip install -q gymnasium\n",
    "\n",
    "    !touch .setup_complete\n",
    "\n",
    "# This code creates a virtual display to draw game images on.\n",
    "# It will have no effect if your machine has a monitor.\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's make a TRPO!\n",
    "\n",
    "In this notebook we will write the code of the one Trust Region Policy Optimization.\n",
    "As usually, it contains a few different parts which we are going to reproduce.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "env = gym.make(\"Acrobot-v1\", render_mode=\"rgb_array\")\n",
    "env.reset()\n",
    "observation_shape = env.observation_space.shape\n",
    "n_actions = env.action_space.n\n",
    "\n",
    "print(\"Observation Space\", env.observation_space)\n",
    "print(\"Action Space\", env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(env.render())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Defining a network\n",
    "\n",
    "With all it's complexity, at it's core TRPO is yet another policy gradient method.\n",
    "\n",
    "This essentially means we're actually training a stochastic policy $\\pi_\\theta \\left( a \\middle| s \\right)$.\n",
    "\n",
    "And yes, it's gonna be a neural network. So let's start by defining one."
   ]
  },
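  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the log-probability output. One common way to obtain it (an assumption here, not the only valid design) is to let the last linear layer produce unnormalized logits $z$, one per action, and normalize them with log-softmax:\n",
    "\n",
    "$$\\log \\pi_\\theta \\left( a \\middle| s \\right) = z_a - \\log \\sum_{b} e^{z_b}$$\n",
    "\n",
    "For example, logits $z = (0, 0, 0)$ give $\\log \\pi_\\theta \\left( a \\middle| s \\right) = -\\log 3 \\approx -1.0986$ for every action, i.e. a uniform policy. In PyTorch this normalization is available as `F.log_softmax(logits, dim=-1)`."
   ]
  },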
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TRPOAgent(nn.Module):\n",
    "    def __init__(self, state_shape: Tuple[int], n_actions: int):\n",
    "        '''\n",
    "        Here you should define your model\n",
    "        You should have LOG-PROBABILITIES as output because you will need it to compute loss\n",
    "        We recommend that you start simple:\n",
    "        use 1-2 hidden layers with 100-500 units and relu for the first try\n",
    "        '''\n",
    "        super().__init__()\n",
    "\n",
    "        assert isinstance(state_shape, tuple)\n",
    "        assert len(state_shape) == 1\n",
    "        input_dim = state_shape[0]\n",
    "        \n",
    "        # Prepare your model here.\n",
    "        <YOUR CODE>\n",
    "\n",
    "    def forward(self, states: torch.Tensor):\n",
    "        \"\"\"\n",
    "        takes agent's observation, returns log-probabilities\n",
    "        :param state_t: a batch of states, shape = [batch_size, state_shape]\n",
    "        \"\"\"\n",
    "\n",
    "        # Use your network to compute log_probs for the given states.\n",
    "        <YOUR CODE>\n",
    "        \n",
    "        return log_probs\n",
    "\n",
    "    def get_log_probs(self, states: torch.Tensor):\n",
    "        '''\n",
    "        Log-probs for training\n",
    "        '''\n",
    "        return self.forward(states)\n",
    "\n",
    "    def get_probs(self, states: torch.Tensor):\n",
    "        '''\n",
    "        Probs for interaction\n",
    "        '''\n",
    "        return torch.exp(self.forward(states))\n",
    "\n",
    "    def act(self, obs: np.ndarray, sample: bool = True):\n",
    "        '''\n",
    "        Samples action from policy distribution (sample = True) or takes most likely action (sample = False)\n",
    "        :param: obs - single observation vector\n",
    "        :param sample: if True, samples from \\pi, otherwise takes most likely action\n",
    "        :returns: action (single integer) and probabilities for all actions\n",
    "        '''\n",
    "\n",
    "        with torch.no_grad():\n",
    "            probs = self.get_probs(torch.tensor(obs[np.newaxis], dtype=torch.float32)).numpy()\n",
    "\n",
    "        if sample:\n",
    "            action = int(np.random.choice(n_actions, p=probs[0]))\n",
    "        else:\n",
    "            action = int(np.argmax(probs))\n",
    "\n",
    "        return action, probs[0]\n",
    "\n",
    "\n",
    "agent = TRPOAgent(observation_shape, n_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if log-probabilities satisfies all the requirements\n",
    "log_probs = agent.get_log_probs(torch.tensor(env.reset()[0][np.newaxis], dtype=torch.float32))\n",
    "assert (\n",
    "    isinstance(log_probs, torch.Tensor) and\n",
    "    log_probs.requires_grad\n",
    "), \"log_probs must be a torch.Tensor with grad\"\n",
    "assert log_probs.shape == (1, n_actions)\n",
    "sums = torch.exp(log_probs).sum(dim=1)\n",
    "assert torch.allclose(sums, torch.ones_like(sums))\n",
    "\n",
    "# Demo use\n",
    "print(\"sampled:\", [agent.act(env.reset()[0]) for _ in range(5)])\n",
    "print(\"greedy:\", [agent.act(env.reset()[0], sample=False) for _ in range(5)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Flat parameters operations\n",
    "\n",
    "We are going to use it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_flat_params_from(model):\n",
    "    params = [torch.ravel(param.detach()) for param in model.parameters()]\n",
    "    flat_params = torch.cat(params)\n",
    "    return flat_params\n",
    "\n",
    "\n",
    "def set_flat_params_to(model, flat_params):\n",
    "    prev_ind = 0\n",
    "    for param in model.parameters():\n",
    "        flat_size = int(np.prod(list(param.shape)))\n",
    "        param.data.copy_(\n",
    "            flat_params[prev_ind:prev_ind + flat_size].reshape(param.shape)\n",
    "        )\n",
    "        prev_ind += flat_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute cumulative reward just like you did in vanilla REINFORCE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.signal\n",
    "\n",
    "\n",
    "def get_cumulative_returns(r, gamma=1):\n",
    "    \"\"\"\n",
    "    Computes cumulative discounted rewards given immediate rewards\n",
    "    G_i = r_i + gamma*r_{i+1} + gamma^2*r_{i+2} + ...\n",
    "    Also known as R(s,a).\n",
    "    \"\"\"\n",
    "    r = np.array(r)\n",
    "    assert r.ndim >= 1\n",
    "    return scipy.signal.lfilter([1], [1, -gamma], r[::-1], axis=0)[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple demo on rewards [0,0,1,0,0,1]\n",
    "get_cumulative_returns([0, 0, 1, 0, 0, 1], gamma=0.9)"
   ]
  },
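  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of the demo above (a worked example added for illustration), the returns can be computed backwards with the recursion $G_i = r_i + \\gamma G_{i+1}$, using rewards $r_0, \\dots, r_5 = 0, 0, 1, 0, 0, 1$ and $\\gamma = 0.9$:\n",
    "\n",
    "$$\\begin{aligned}\n",
    "G_5 &= 1 \\\\\n",
    "G_4 &= 0 + 0.9 \\cdot 1 = 0.9 \\\\\n",
    "G_3 &= 0 + 0.9 \\cdot 0.9 = 0.81 \\\\\n",
    "G_2 &= 1 + 0.9 \\cdot 0.81 = 1.729 \\\\\n",
    "G_1 &= 0 + 0.9 \\cdot 1.729 = 1.5561 \\\\\n",
    "G_0 &= 0 + 0.9 \\cdot 1.5561 = 1.40049\n",
    "\\end{aligned}$$\n",
    "\n",
    "So the call should return approximately `[1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0]`."
   ]
  },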
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Rollout**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(env, agent, max_pathlength=2500, n_timesteps=50000):\n",
    "    \"\"\"\n",
    "    Generate rollouts for training.\n",
    "    :param: env - environment in which we will make actions to generate rollouts.\n",
    "    :param: act - the function that can return policy and action given observation.\n",
    "    :param: max_pathlength - maximum size of one path that we generate.\n",
    "    :param: n_timesteps - total sum of sizes of all pathes we generate.\n",
    "    \"\"\"\n",
    "    paths = []\n",
    "\n",
    "    total_timesteps = 0\n",
    "    while total_timesteps < n_timesteps:\n",
    "        obervations, actions, rewards, action_probs = [], [], [], []\n",
    "        obervation, _ = env.reset()\n",
    "        for _ in range(max_pathlength):\n",
    "            action, policy = agent.act(obervation)\n",
    "            obervations.append(obervation)\n",
    "            actions.append(action)\n",
    "            action_probs.append(policy)\n",
    "            obervation, reward, terminated, truncated, _ = env.step(action)\n",
    "            rewards.append(reward)\n",
    "            total_timesteps += 1\n",
    "            if terminated or truncated or total_timesteps >= n_timesteps:\n",
    "                path = {\n",
    "                    \"observations\": np.array(obervations),\n",
    "                    \"policy\": np.array(action_probs),\n",
    "                    \"actions\": np.array(actions),\n",
    "                    \"rewards\": np.array(rewards),\n",
    "                    \"cumulative_returns\": get_cumulative_returns(rewards),\n",
    "                }\n",
    "                paths.append(path)\n",
    "                break\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "paths = rollout(env, agent, max_pathlength=5, n_timesteps=100)\n",
    "pprint(paths[-1])\n",
    "\n",
    "assert (paths[0]['policy'].shape == (5, n_actions))\n",
    "assert (paths[0]['cumulative_returns'].shape == (5,))\n",
    "assert (paths[0]['rewards'].shape == (5,))\n",
    "assert (paths[0]['observations'].shape == (5,) + observation_shape)\n",
    "assert (paths[0]['actions'].shape == (5,))\n",
    "\n",
    "print(\"It's ok\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Auxiliary functions\n",
    "\n",
    "Now let's define the loss functions and something else for actual TRPO training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The surrogate reward should be:\n",
    "$$J_{surr}= {1 \\over N} \\sum\\limits_{i=1}^N \\frac{\\pi_{\\theta}(s_i, a_i)}{\\pi_{\\theta_{old}}(s_i, a_i)}A_{\\theta_{old}(s_i, a_i)}$$\n",
    "\n",
    "For simplicity, in this assignment we are going to use cumulative rewards instead of advantage:\n",
    "$$J'_{surr}= {1 \\over N} \\sum\\limits_{i=1}^N \\frac{\\pi_{\\theta}(s_i, a_i)}{\\pi_{\\theta_{old}}(s_i, a_i)}G_{\\theta_{old}(s_i, a_i)}$$\n",
    "\n",
    "Since we want to maximize the reward, we are going to minimize the corresponding surrogate loss:\n",
    "$$ L_{surr} = - J'_{surr} $$\n"
   ]
  },
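  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why this surrogate is a sensible objective: at $\\theta = \\theta_{old}$ its gradient coincides with the usual REINFORCE policy gradient estimate. A short derivation (added as an illustration), using the identity $\\nabla_\\theta \\pi_\\theta = \\pi_\\theta \\nabla_\\theta \\log \\pi_\\theta$:\n",
    "\n",
    "$$\\nabla_\\theta J'_{surr} \\Big|_{\\theta = \\theta_{old}} = {1 \\over N} \\sum\\limits_{i=1}^N \\frac{\\nabla_\\theta \\pi_{\\theta}(a_i \\mid s_i)}{\\pi_{\\theta_{old}}(a_i \\mid s_i)} \\Big|_{\\theta = \\theta_{old}} G_{\\theta_{old}}(s_i, a_i) = {1 \\over N} \\sum\\limits_{i=1}^N \\nabla_\\theta \\log \\pi_{\\theta}(a_i \\mid s_i) \\Big|_{\\theta = \\theta_{old}} G_{\\theta_{old}}(s_i, a_i)$$\n",
    "\n",
    "Away from $\\theta_{old}$ the two objectives differ, which is why the update has to stay within a trust region."
   ]
  },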
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_loss(agent, observations, actions, cumulative_returns, old_probs):\n",
    "    \"\"\"\n",
    "    Computes TRPO objective\n",
    "    :param: observations - batch of observations [timesteps x state_shape]\n",
    "    :param: actions - batch of actions [timesteps]\n",
    "    :param: cumulative_returns - batch of cumulative returns [timesteps]\n",
    "    :param: old_probs - batch of probabilities computed by old network [timesteps x num_actions]\n",
    "    :returns: scalar value of the objective function\n",
    "    \"\"\"\n",
    "    batch_size = observations.shape[0]\n",
    "    probs_all = agent.get_probs(observations)\n",
    "\n",
    "    probs_for_actions = probs_all[torch.arange(batch_size), actions]\n",
    "    old_probs_for_actions = old_probs[torch.arange(batch_size), actions]\n",
    "\n",
    "    # Compute surrogate loss, aka importance-sampled policy gradient\n",
    "    <YOUR CODE>\n",
    "\n",
    "    assert loss.ndim == 0\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can ascend these gradients as long as our $\\pi_\\theta(a|s)$ satisfies the constraint\n",
    "$$\\mathbb{E}_{s,\\pi_{\\theta_{t}}} \\Big[ \\operatorname{KL} \\left( \\pi_{\\theta_{t}} (s) \\:\\|\\: \\pi_{\\theta_{t+1}} (s) \\right) \\Big] < \\alpha$$\n",
    "\n",
    "\n",
    "where\n",
    "\n",
    "$$\\operatorname{KL} \\left( p \\| q \\right) = \\mathbb{E}_p \\log \\left( \\frac p q \\right)$$"
   ]
  },
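  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a discrete action space, the expectation in the KL definition is a plain sum over actions (written out here as an illustration):\n",
    "\n",
    "$$\\operatorname{KL} \\left( p \\| q \\right) = \\sum_a p(a) \\log \\frac{p(a)}{q(a)}$$\n",
    "\n",
    "For instance, with two actions, $p = (0.5, 0.5)$ and $q = (0.25, 0.75)$:\n",
    "\n",
    "$$\\operatorname{KL} \\left( p \\| q \\right) = 0.5 \\log \\frac{0.5}{0.25} + 0.5 \\log \\frac{0.5}{0.75} \\approx 0.3466 - 0.2027 = 0.1438$$\n",
    "\n",
    "The divergence is zero when $p = q$ and it is not symmetric in its arguments, so the order of the old and the new policy matters."
   ]
  },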
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_kl(agent, observations, actions, cumulative_returns, old_probs):\n",
    "    \"\"\"\n",
    "    Computes KL-divergence between network policy and old policy\n",
    "    :param: observations - batch of observations [timesteps x state_shape]\n",
    "    :param: actions - batch of actions [timesteps]\n",
    "    :param: cumulative_returns - batch of cumulative returns [timesteps] (we don't need it actually)\n",
    "    :param: old_probs - batch of probabilities computed by old network [timesteps x num_actions]\n",
    "    :returns: scalar value of the KL-divergence\n",
    "    \"\"\"\n",
    "    batch_size = observations.shape[0]\n",
    "    log_probs_all = agent.get_log_probs(observations)\n",
    "    probs_all = torch.exp(log_probs_all)\n",
    "\n",
    "    # Compute Kullback-Leibler divergence (see formula above).\n",
    "    # Note: you need to sum KL and entropy over all actions, not just the ones agent took.\n",
    "    # You will also need to compute max KL over all timesteps.\n",
    "    old_log_probs = torch.log(old_probs + 1e-10)\n",
    "\n",
    "    <YOUR CODE>\n",
    "\n",
    "    assert kl.ndim == 0\n",
    "    assert (kl > -0.0001).all() and (kl < 10000).all()\n",
    "    return kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_entropy(agent, observations):\n",
    "    \"\"\"\n",
    "    Computes entropy of the network policy\n",
    "    :param: observations - batch of observations\n",
    "    :returns: scalar value of the entropy\n",
    "    \"\"\"\n",
    "\n",
    "    observations = torch.tensor(observations, dtype=torch.float32)\n",
    "\n",
    "    log_probs_all = agent.get_log_probs(observations)\n",
    "    probs_all = torch.exp(log_probs_all)\n",
    "\n",
    "    entropy = (-probs_all * log_probs_all).sum(dim=1).mean(dim=0)\n",
    "\n",
    "    assert entropy.ndim == 0\n",
    "    return entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Linear search**\n",
    "\n",
    "TRPO in its core involves ascending surrogate policy gradient constrained by KL divergence.\n",
    "\n",
    "In order to enforce this constraint, we're gonna use linesearch. You can find out more about it [here](https://en.wikipedia.org/wiki/Linear_search)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linesearch(f, x: torch.Tensor, fullstep: torch.Tensor, max_kl: float, max_backtracks: int = 10, backtrack_coef: float = 0.5):\n",
    "    \"\"\"\n",
    "    Linesearch finds the best parameters of neural networks in the direction of fullstep contrainted by KL divergence.\n",
    "    :param: f - function that returns loss, kl and arbitrary third component.\n",
    "    :param: x - old parameters of neural network.\n",
    "    :param: fullstep - direction in which we make search.\n",
    "    :param: max_kl - constraint of KL divergence.\n",
    "    :returns:\n",
    "    \"\"\"\n",
    "    loss, _, = f(x)\n",
    "    for stepfrac in backtrack_coef**np.arange(max_backtracks):\n",
    "        xnew = x + stepfrac * fullstep\n",
    "        new_loss, kl = f(xnew)\n",
    "        if kl <= max_kl and new_loss < loss:\n",
    "            x = xnew\n",
    "            loss = new_loss\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Conjugate gradients**\n",
    "\n",
    "Since TRPO includes contrainted optimization, we will need to solve $A x = b$ using conjugate gradients.\n",
    "\n",
    "In general, CG is an algorithm that solves $A x = b$ where $A$ is positive-defined. $A$ is the Hessian matrix so $A$ is positive-defined. You can find out more about CG [here](https://en.wikipedia.org/wiki/Conjugate_gradient_method)."
   ]
  },
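  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The iteration implemented in the next cell can be summarized by the standard CG recurrences, written here with the same names as the code variables (`z`, `v`, `mu`) and starting from $x_0 = 0$, $r_0 = p_0 = b$:\n",
    "\n",
    "$$\\begin{aligned}\n",
    "z_k &= A p_k \\\\\n",
    "v_k &= \\frac{r_k^T r_k}{p_k^T z_k} \\\\\n",
    "x_{k+1} &= x_k + v_k p_k \\\\\n",
    "r_{k+1} &= r_k - v_k z_k \\\\\n",
    "\\mu_k &= \\frac{r_{k+1}^T r_{k+1}}{r_k^T r_k} \\\\\n",
    "p_{k+1} &= r_{k+1} + \\mu_k p_k\n",
    "\\end{aligned}$$\n",
    "\n",
    "Only products $A p_k$ are needed, never the matrix $A$ itself, which is what makes the method practical for a network with many parameters. The code additionally adds $10^{-8}$ to both denominators to avoid division by zero."
   ]
  },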
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):\n",
    "    \"\"\"\n",
    "    This method solves system of equation Ax=b using an iterative method called conjugate gradients\n",
    "    :f_Ax: function that returns Ax\n",
    "    :b: targets for Ax\n",
    "    :cg_iters: how many iterations this method should do\n",
    "    :residual_tol: epsilon for stability\n",
    "    \"\"\"\n",
    "    p = b.clone()\n",
    "    r = b.clone()\n",
    "    x = torch.zeros_like(b)\n",
    "    rdotr = torch.sum(r*r)\n",
    "    for i in range(cg_iters):\n",
    "        z = f_Ax(p)\n",
    "        v = rdotr / (torch.sum(p*z) + 1e-8)\n",
    "        x += v * p\n",
    "        r -= v * z\n",
    "        newrdotr = torch.sum(r*r)\n",
    "        mu = newrdotr / (rdotr + 1e-8)\n",
    "        p = r + mu * p\n",
    "        rdotr = newrdotr\n",
    "        if rdotr < residual_tol:\n",
    "            break\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code validates conjugate gradients\n",
    "A = np.random.rand(8, 8)\n",
    "A = A.T @ A\n",
    "\n",
    "\n",
    "def f_Ax(x):\n",
    "    return torch.ravel(torch.tensor(A, dtype=torch.float32) @ x.reshape(-1, 1))\n",
    "\n",
    "\n",
    "b = np.random.rand(8)\n",
    "w = (np.linalg.inv(A.T @ A) @ A.T @ b.reshape(-1, 1)).reshape(-1)\n",
    "\n",
    "print(w)\n",
    "print(conjugate_gradient(f_Ax, torch.tensor(b, dtype=torch.float32)).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: training\n",
    "In this section we construct the whole update step function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_step(agent, observations, actions, cumulative_returns, old_probs, max_kl):\n",
    "    \"\"\"\n",
    "    This function does the TRPO update step\n",
    "    :param: observations - batch of observations\n",
    "    :param: actions - batch of actions\n",
    "    :param: cumulative_returns - batch of cumulative returns\n",
    "    :param: old_probs - batch of probabilities computed by old network\n",
    "    :param: max_kl - controls how big KL divergence may be between old and new policy every step.\n",
    "    :returns: KL between new and old policies and the value of the loss function.\n",
    "    \"\"\"\n",
    "\n",
    "    # Here we prepare the information\n",
    "    observations = torch.tensor(observations, dtype=torch.float32)\n",
    "    actions = torch.tensor(actions, dtype=torch.int64)\n",
    "    cumulative_returns = torch.tensor(cumulative_returns, dtype=torch.float32)\n",
    "    old_probs = torch.tensor(old_probs, dtype=torch.float32)\n",
    "\n",
    "    # Here we compute gradient of the loss function\n",
    "    loss = get_loss(agent, observations, actions, cumulative_returns, old_probs)\n",
    "    grads = torch.autograd.grad(loss, agent.parameters())\n",
    "    loss_grad = torch.cat([torch.ravel(grad.detach()) for grad in grads])\n",
    "\n",
    "    def Fvp(v):\n",
    "        # Here we compute Fx to do solve Fx = g using conjugate gradients\n",
    "        # We actually do here a couple of tricks to compute it efficiently\n",
    "\n",
    "        kl = get_kl(agent, observations, actions, cumulative_returns, old_probs)\n",
    "\n",
    "        grads = torch.autograd.grad(kl, agent.parameters(), create_graph=True)\n",
    "        flat_grad_kl = torch.cat([grad.reshape(-1) for grad in grads])\n",
    "\n",
    "        kl_v = (flat_grad_kl * v).sum()\n",
    "        grads = torch.autograd.grad(kl_v, agent.parameters())\n",
    "        flat_grad_grad_kl = torch.cat([torch.ravel(grad) for grad in grads]).detach()\n",
    "\n",
    "        return flat_grad_grad_kl + v * 0.1\n",
    "\n",
    "    # Here we solve Fx = g system using conjugate gradients\n",
    "    stepdir = conjugate_gradient(Fvp, -loss_grad, 10)\n",
    "\n",
    "    # Here we compute the initial vector to do linear search\n",
    "    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)\n",
    "\n",
    "    lm = torch.sqrt(shs / max_kl)\n",
    "    fullstep = stepdir / lm[0]\n",
    "\n",
    "    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)\n",
    "\n",
    "    # Here we get the start point\n",
    "    prev_params = get_flat_params_from(agent)\n",
    "\n",
    "    def get_loss_kl(params):\n",
    "        # Helper for linear search\n",
    "        set_flat_params_to(agent, params)\n",
    "        return [\n",
    "            get_loss(agent, observations, actions, cumulative_returns, old_probs),\n",
    "            get_kl(agent, observations, actions, cumulative_returns, old_probs),\n",
    "        ]\n",
    "\n",
    "    # Here we find our new parameters\n",
    "    new_params = linesearch(get_loss_kl, prev_params, fullstep, max_kl)\n",
    "\n",
    "    # And we set it to our network\n",
    "    set_flat_params_to(agent, new_params)\n",
    "\n",
    "    return get_loss_kl(new_params)"
   ]
  },
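  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the step scaling in `update_step` above (a sketch of the standard argument, assuming a second-order approximation of the KL divergence). Near $\\theta_{old}$ the KL between the old and the new policy behaves like a quadratic form, where $F$ is the Hessian of the KL at $\\theta_{old}$:\n",
    "\n",
    "$$\\operatorname{KL} \\approx {1 \\over 2} \\Delta\\theta^T F \\Delta\\theta$$\n",
    "\n",
    "For a step $\\Delta\\theta = \\beta s$ along the conjugate gradient direction $s$, setting the quadratic form equal to the limit $\\delta$ (`max_kl`) gives\n",
    "\n",
    "$${1 \\over 2} \\beta^2 s^T F s = \\delta \\quad \\Rightarrow \\quad \\beta = \\sqrt{\\frac{2 \\delta}{s^T F s}}$$\n",
    "\n",
    "In the code, `shs` is ${1 \\over 2} s^T F s$, `lm` is $\\sqrt{\\text{shs} / \\delta} = 1 / \\beta$, and `fullstep = stepdir / lm` is $\\beta s$. The product $F v$ itself is obtained in `Fvp` as $\\nabla_\\theta \\left( v^T \\nabla_\\theta \\operatorname{KL} \\right)$ with $v$ held constant, plus a damping term $0.1 v$."
   ]
  },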
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Main TRPO loop\n",
    "\n",
    "Here we will train our network!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from itertools import count\n",
    "\n",
    "# TRPO hyperparameter; controls how big KL divergence may be between the old and the new policy at every step.\n",
    "max_kl = 0.01\n",
    "numeptotal = 0  # Number of episodes we have completed so far.\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "for i in count(1):\n",
    "    print(\"\\n********** Iteration %i ************\" % i)\n",
    "\n",
    "    # Generating paths.\n",
    "    print(\"Rollout\")\n",
    "    paths = rollout(env, agent)\n",
    "    print(\"Made rollout\")\n",
    "\n",
    "    # Updating policy.\n",
    "    observations = np.concatenate([path[\"observations\"] for path in paths])\n",
    "    actions = np.concatenate([path[\"actions\"] for path in paths])\n",
    "    returns = np.concatenate([path[\"cumulative_returns\"] for path in paths])\n",
    "    old_probs = np.concatenate([path[\"policy\"] for path in paths])\n",
    "\n",
    "    loss, kl = update_step(agent, observations, actions, returns, old_probs, max_kl)\n",
    "\n",
    "    # Report current progress\n",
    "    episode_rewards = np.array([path[\"rewards\"].sum() for path in paths])\n",
    "\n",
    "    stats = {}\n",
    "    numeptotal += len(episode_rewards)\n",
    "    stats[\"Total number of episodes\"] = numeptotal\n",
    "    stats[\"Average sum of rewards per episode\"] = episode_rewards.mean()\n",
    "    stats[\"Std of rewards per episode\"] = episode_rewards.std()\n",
    "    stats[\"Time elapsed\"] = \"%.2f mins\" % ((time.time() - start_time)/60.)\n",
    "    stats[\"KL between old and new distribution\"] = kl.data.numpy()\n",
    "    stats[\"Entropy\"] = get_entropy(agent, observations).data.numpy()\n",
    "    stats[\"Surrogate loss\"] = loss.data.numpy()\n",
    "    for k, v in stats.items():\n",
    "        print(k + \": \" + \" \" * (40 - len(k)) + str(v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework option I: better sampling (10+pts)\n",
    "\n",
    "In this section, you're invited to implement a better rollout strategy called _vine_.\n",
    "\n",
    "![img](https://s17.postimg.cc/i90chxgvj/vine.png)\n",
    "\n",
    "In most gym environments, you can actually backtrack by using states. You can find a wrapper that saves/loads states in [the MCTS seminar](https://github.com/yandexdataschool/Practical_RL/blob/master/week10_planning/seminar_MCTS.ipynb).\n",
    "\n",
    "You can read more about TRPO in the [original paper](https://arxiv.org/abs/1502.05477) in section 5.2.\n",
    "\n",
    "The goal here is to implement such rollout policy (we recommend using tree data structure like in the seminar above).\n",
    "Then you can assign cumulative rewards similar to `get_cumulative_rewards`, but for a tree.\n",
    "\n",
    "__bonus task__ - parallelize samples using multiple cores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework option II (10+pts)\n",
    "\n",
    "Let's use TRPO to train evil robots! (pick any of two)\n",
    "* [MuJoCo robots](https://gymnasium.farama.org/environments/mujoco/#mujoco)\n",
    "* [Box2d robot](https://gymnasium.farama.org/environments/box2d/bipedal_walker/)\n",
    "\n",
    "The catch here is that those environments have continuous action spaces.\n",
    "\n",
    "Luckily, TRPO is a policy gradient method, so it's gonna work for any parametric $\\pi_\\theta(a|s)$. We recommend starting with gaussian policy:\n",
    "\n",
    "$$\\pi_\\theta(a|s) = N(\\mu_\\theta(s),\\sigma^2_\\theta(s)) = {1 \\over \\sqrt { 2 \\pi {\\sigma^2}_\\theta(s) } } e^{ (a -\n",
    "\\mu_\\theta(s))^2 \\over 2 {\\sigma^2}_\\theta(s) } $$\n",
    "\n",
    "In the $\\sqrt { 2 \\pi {\\sigma^2}_\\theta(s) }$ clause, $\\pi$ means ~3.1415926, not agent's policy.\n",
    "\n",
    "This essentially means that you will need two output layers:\n",
    "* $\\mu_\\theta(s)$, a dense layer with linear activation\n",
    "* ${\\sigma^2}_\\theta(s)$, a dense layer with activation tf.exp (to make it positive; like rho from bandits)\n",
    "\n",
    "For multidimensional actions, you can use a fully factorized gaussian (basically a vector of gaussians).\n",
    "\n",
    "__Bonus task__: compare the performance of the continuous action space method to action space discretization."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
